Phase Mask Design
In this notebook, we illustrate the inverse design of a phase mask, using the example from Wong et al. (2021): designing a diffractive pupil phase mask for the Toliman telescope.
To obtain high-precision centroids, we need to maximize the gradient energy of the pupil; to satisfy fabrication constraints, we need a binary mask whose phases take only the values {0, π}.
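One way to make "gradient energy" concrete is sketched below. This is an assumption for illustration rather than the definition used by Wong et al.: the pupil is propagated to the focal plane with a single zero-padded FFT, and the gradient energy is taken to be the sum of squared finite-difference gradients of the resulting PSF. The helpers `psf_from_phase` and `gradient_energy`, and the stripe-pattern test pupil, are hypothetical.

```python
import jax
import jax.numpy as np

def psf_from_phase(phase, support, pad=2):
    """Focal-plane intensity of a pupil via one zero-padded FFT (simplified model)."""
    n = phase.shape[0]
    field = support * np.exp(1j * phase)
    padded = np.pad(field, ((0, (pad - 1) * n), (0, (pad - 1) * n)))
    psf = np.abs(np.fft.fftshift(np.fft.fft2(padded))) ** 2
    return psf / psf.sum()  # normalise to unit total flux

def gradient_energy(image):
    """Sum of squared finite-difference gradients along both image axes."""
    gy, gx = np.gradient(image)
    return np.sum(gx ** 2 + gy ** 2)

# Hypothetical test pupil: circular aperture with a binary {0, pi} stripe pattern
n = 64
xs = np.linspace(-1.0, 1.0, n)
XX, YY = np.meshgrid(xs, xs)
support = (np.hypot(XX, YY) <= 1.0).astype(float)
binary = (np.sin(6 * np.pi * XX) > 0).astype(float)
phase = np.pi * binary

psf = psf_from_phase(phase, support)
print("gradient energy:", gradient_energy(psf))

# Every step is a JAX operation, so jax.grad returns d(loss)/d(phase),
# which a gradient-based optimizer could follow to increase the gradient energy.
loss = lambda p: -gradient_energy(psf_from_phase(p, support))
g = jax.grad(loss)(phase)
```

The binary constraint is not enforced by the gradient itself; that is the job of the binarization step described next.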
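import warnings
warnings.filterwarnings("ignore")  # many of the functions used here are still under development

import jax
import jax.numpy as np  # note: `np` is jax.numpy throughout this notebook
from jax import vmap, jit

# Enable 64-bit floats before any arrays are created.
# jax.config.update works across JAX versions; the older
# `from jax.config import config` import no longer works in recent releases.
jax.config.update("jax_enable_x64", True)

import numpy as onp  # plain NumPy, where needed
import matplotlib.pyplot as plt
from tqdm import tqdm

%matplotlib inline
plt.rcParams['image.cmap'] = 'hot'
plt.rcParams['text.usetex'] = False  # a boolean, not the string 'false'
plt.rcParams['figure.dpi'] = 120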
We first generate an orthonormal basis for the pupil phases, then threshold it to {0, 1} (corresponding to phases {0, π}) while preserving soft edges, using the Continuous Latent Image Mask Binarization (CLIMB) algorithm from Wong et al.
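The soft-edge idea can be illustrated with a simplified sketch. It is assumed from the description above and is not the CLIMB implementation from Wong et al.: a continuous latent image is defined on a grid `oversample` times finer than the output, thresholded to {0, 1} there, and then block-averaged down, so output pixels crossed by the zero contour take fractional values. The real algorithm must also supply useful gradients through the threshold, which this sketch does not attempt. The helpers `block_average` and `soft_binarise_sketch` are hypothetical.

```python
import jax.numpy as np

def block_average(image, factor):
    """Downsample a square image by averaging non-overlapping factor x factor blocks."""
    n = image.shape[0] // factor
    return image.reshape(n, factor, n, factor).mean(axis=(1, 3))

def soft_binarise_sketch(latent, oversample):
    """Threshold an oversampled latent image at zero, then block-average.

    Output pixels entirely on one side of the zero contour are exactly 0 or 1;
    pixels crossed by the contour get the fraction of sub-pixels above zero.
    """
    hard = (latent > 0).astype(float)
    return block_average(hard, oversample)

# Hypothetical example: a tilted straight edge on an 8 x 8 output grid
oversample, n_out = 3, 8
xs = np.arange(n_out * oversample) + 0.5
XX, YY = np.meshgrid(xs, xs)
latent = XX + 0.5 * YY - 12.0          # positive on one side of a straight line
mask01 = soft_binarise_sketch(latent, oversample)
phase = np.pi * mask01                 # {0, 1} -> {0, pi}, fractional only along the edge
```

With `oversample = 3`, every fractional value is a multiple of 1/9, the area fraction of a 3 × 3 block.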
from sklearn.decomposition import PCA
Generate the support of the pupil:
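wf_npix = 256                 # pixels across the wavefront
oversample = 3                # oversampling factor of the latent grid
nslice = 3
npix = wf_npix * oversample   # 768 pixels across the oversampled grid
c = (npix - 1) / 2.
xs = (np.arange(npix, dtype=np.float64) - c) / c   # pixel-centre coordinates in [-1, 1]
XX, YY = np.meshgrid(xs, xs)
RR = np.sqrt(XX ** 2 + YY ** 2)   # radius r
PHI = np.arctan2(YY, XX)          # azimuth θ
# Annular aperture: outer radius 1, central obscuration of radius 0.175
mask = np.logical_and(RR <= 1, RR >= 0.175).astype(float)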
plt.imshow(mask)
plt.colorbar()
plt.show()
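As a quick sanity check on the support, the fraction of the square grid inside the annulus 0.175 ≤ r ≤ 1 should approach the ratio of the annulus area to the area of the square, π(1 − 0.175²)/4 ≈ 0.761. The sketch rebuilds the same grid so it runs on its own; `annulus_fill_fraction` is a hypothetical helper, not part of the notebook.

```python
import jax.numpy as np

def annulus_fill_fraction(npix, r_in=0.175, r_out=1.0):
    """Fraction of grid pixels inside the annulus, on the notebook's [-1, 1] grid."""
    c = (npix - 1) / 2.0
    xs = (np.arange(npix) - c) / c
    XX, YY = np.meshgrid(xs, xs)
    RR = np.sqrt(XX ** 2 + YY ** 2)
    mask = np.logical_and(RR <= r_out, RR >= r_in).astype(float)
    return float(mask.mean())

expected = np.pi * (1 - 0.175 ** 2) / 4
print(annulus_fill_fraction(768), float(expected))
```

The two numbers differ slightly because pixel centres lie exactly on ±1, so the sampled square is marginally larger than [-1, 1]².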
Generate basis vectors however you like. Here we use logarithmic radial harmonic functions (LRHFs), which vary in both r and θ, together with sines and cosines in r; any other choice works too. This code is not important: just generate your favourite not-necessarily-orthonormal basis, and we will orthonormalize it with PCA afterwards.
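Written out, the functions generated by the following cell are (with r = `RR`, θ = `PHI`, and ε = 10⁻¹² guarding against log 0):

$$
\begin{aligned}
f_{A,B,C}(r,\theta) &= \cos\big(A\ln(r+\varepsilon) + B\theta + C\big), && A \in \{-10,\dots,10\},\ B \in \{0,3,\dots,24\},\ C \in \{-\tfrac{\pi}{2},\tfrac{\pi}{2}\},\\
s_i(r) &= \sin(i\pi r),\quad c_i(r) = \cos(i\pi r), && i \in \{-10,\dots,10\}.
\end{aligned}
$$

The LRHFs carry all of the azimuthal structure; the sines and cosines are purely radial.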
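# a = 20   # a larger alternative basis
# b = 8
# ith = 40
a = 10     # LRHF radial orders A run from -a to a
b = 8      # LRHF azimuthal orders B run from 0 to 3*b in steps of 3
ith = 10   # radial sine/cosine orders run from -ith to ith
As = np.arange(-a, a+1)
Bs = 3 * np.arange(0, b+1)
Cs = np.array([-np.pi/2, np.pi/2])  # phase offsets (note: ±π/2 give sign-flipped copies of each other)
Is = np.arange(-ith, ith+1)

# Logarithmic radial harmonic function; the 1e-12 guards against log(0)
LRHF_fn = lambda A, B, C, RR, PHI: np.cos(A*np.log(RR + 1e-12) + B*PHI + C)
sine_fn = lambda i, RR: np.sin(i * np.pi * RR)
cose_fn = lambda i, RR: np.cos(i * np.pi * RR)

# Nested vmaps over B (innermost), then A, then C:
# output shape is (len(Cs), len(As), len(Bs), npix, npix)
gen_LRHF_basis = vmap(vmap(vmap(LRHF_fn, (None, 0, None, None, None)), (0, None, None, None, None)), (None, None, 0, None, None))
gen_sine_basis = vmap(sine_fn, in_axes=(0, None))
gen_cose_basis = vmap(cose_fn, in_axes=(0, None))

LRHF_basis = gen_LRHF_basis(As, Bs, Cs, RR, PHI).reshape([len(As)*len(Bs)*len(Cs), npix, npix])
sine_basis = gen_sine_basis(Is, RR)
cose_basis = gen_cose_basis(Is, RR)

# Flatten each basis image into a row vector of length npix*npix
LRHF_flat = LRHF_basis.reshape([len(As)*len(Bs)*len(Cs), npix*npix])
sine_flat = sine_basis.reshape([len(sine_basis), npix*npix])
cose_flat = cose_basis.reshape([len(cose_basis), npix*npix])

# Stack into a single (n_vectors, npix*npix) matrix for PCA
full_basis = np.concatenate([
    LRHF_flat,
    sine_flat,
    cose_flat
])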
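Before orthonormalizing, it helps to know how redundant this set is. With a = 10, b = 8 and ith = 10 it contains 21 × 9 × 2 = 378 LRHF vectors plus 21 + 21 = 42 radial sines and cosines, 420 in total, but many are linearly dependent: cos(x − π/2) = sin x and cos(x + π/2) = −sin x, so the two values in `Cs` give sign-flipped copies; the i = 0 sine is identically zero; and sin(−iπr) = −sin(iπr), cos(−iπr) = cos(iπr). The sketch below checks this on a small 64 × 64 grid (chosen only so it runs quickly) by comparing the number of vectors with the numerical rank from `np.linalg.matrix_rank`.

```python
import jax.numpy as np
from jax import vmap

npix = 64                                   # small grid, for speed only
c = (npix - 1) / 2.0
xs = (np.arange(npix) - c) / c
XX, YY = np.meshgrid(xs, xs)
RR, PHI = np.sqrt(XX ** 2 + YY ** 2), np.arctan2(YY, XX)

As, Bs = np.arange(-10, 11), 3 * np.arange(0, 9)
Cs, Is = np.array([-np.pi / 2, np.pi / 2]), np.arange(-10, 11)
LRHF_fn = lambda A, B, C, RR, PHI: np.cos(A * np.log(RR + 1e-12) + B * PHI + C)

# Same nesting as the notebook: output shape (len(Cs), len(As), len(Bs), npix, npix)
gen = vmap(vmap(vmap(LRHF_fn, (None, 0, None, None, None)),
                (0, None, None, None, None)), (None, None, 0, None, None))
lrhf = gen(As, Bs, Cs, RR, PHI)

# Radial sines and cosines, broadcast over the orders in Is
radial = np.concatenate([np.sin(Is[:, None, None] * np.pi * RR),
                         np.cos(Is[:, None, None] * np.pi * RR)])
full = np.concatenate([lrhf.reshape(-1, npix * npix),
                       radial.reshape(-1, npix * npix)])

print(full.shape[0], int(np.linalg.matrix_rank(full)))
```

PCA discards these dependent directions automatically (they carry no variance), which is one reason it is a convenient orthonormalizer here.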
Orthonormalize with PCA; Gram-Schmidt would also work if you prefer.
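A minimal sketch of both routes, on a small random matrix whose rows stand in for `full_basis` (the sizes and data are assumptions for illustration). scikit-learn's PCA centres each feature and then takes an SVD, so its components are the right singular vectors of the centred data; a QR factorization spans the same space as Gram-Schmidt applied to the same vectors, but is numerically more stable than the classical algorithm. Because centring n vectors removes one degree of freedom, the last singular value is numerically zero.

```python
import jax
import jax.numpy as np

key = jax.random.PRNGKey(0)
vectors = jax.random.normal(key, (20, 500))   # 20 hypothetical basis vectors of length 500

# PCA-style: centre each feature across the vectors, then SVD.
# Rows of Vt are orthonormal and ordered by decreasing singular value.
centred = vectors - vectors.mean(axis=0)
_, s, Vt = np.linalg.svd(centred, full_matrices=False)
pca_components = Vt                            # shape (20, 500)

# Gram-Schmidt-style: QR of the column matrix; columns of Q orthonormally span the inputs
Q, R = np.linalg.qr(vectors.T)                 # Q has shape (500, 20)
gs_components = Q.T                            # shape (20, 500)

print(np.abs(pca_components @ pca_components.T - np.eye(20)).max())
print(np.abs(gs_components @ gs_components.T - np.eye(20)).max())
```

One practical difference: PCA orders its components by explained variance, so truncating to the leading components keeps the most significant directions, whereas QR components simply follow the input order.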
%%time
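pca = PCA().fit(full_basis)  # each row of full_basis is one sample
# Orthonormal components, ordered by explained variance
components = pca.components_.reshape([len(full_basis), npix, npix])
components = np.copy(components[:99,:,:])  # keep the 99 leading components
# Prepend a constant image (scaled by the mean value of the PCA mean image): 100 basis vectors in total
basis = np.concatenate([np.mean(pca.mean_) * np.ones((1, npix, npix)), components])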
CPU times: user 1min 46s, sys: 5.12 s, total: 1min 51s
Wall time: 21.9 s
Show the pretty basis vectors:
nfigs = 100                    # plot all 100 basis vectors
ncols = 10
nrows = -(-nfigs // ncols)     # ceiling division, so no empty row is added
plt.figure(figsize=(4*ncols, 4*nrows))
for i in range(nfigs):
    plt.subplot(nrows, ncols, i+1)
    plt.imshow(basis[i], cmap='seismic')
    plt.xticks([])
    plt.yticks([])
    # plt.colorbar()
plt.tight_layout()
plt.show()